import torch
import cv2 as cv
from model.zhnnet import ZhnNet, image_generate_loc


# Live stereo-camera demo: run ZhnNet on the left and right halves of each
# side-by-side camera frame and display both annotated halves.
# Keys: 'q' quits; 's' saves the current raw (un-annotated) frame to err.png.
# For a single-image test, the frame can instead be read with
# cv.imread('E:/dataset/zhnface.png').

PAD = 16  # rows of white border added above and below each frame before inference
HALF_WIDTH = 640  # width of one stereo view; both views arrive side by side

model = ZhnNet()
# The model stays on the CPU below, so map the checkpoint onto the CPU too;
# otherwise a checkpoint saved from a CUDA device fails to load without a GPU.
model.load_state_dict(torch.load('zhnnet.pth', map_location='cpu'))
model.eval()
cap = cv.VideoCapture(0)
cap.set(cv.CAP_PROP_FRAME_WIDTH, 2 * HALF_WIDTH)
cap.set(cv.CAP_PROP_FRAME_HEIGHT, 480)
print('Network loading complete.')


def _to_input(view):
    """Convert one HxWx3 uint8 view to a 3xHxW float32 tensor scaled by 1/256."""
    # NOTE(review): /256 (not /255) presumably matches the training-time
    # scaling — confirm before changing.
    return torch.tensor(view.transpose(2, 0, 1) / 256, dtype=torch.float32)


try:
    if not cap.isOpened():
        raise RuntimeError('Could not open camera 0.')
    while cap.isOpened():
        ret, image_origin = cap.read()
        if not ret:
            break
        height, width = image_origin.shape[:2]
        # cap.set() is only a request; drivers may ignore it.  Without this
        # check a narrower frame yields an empty right half and an opaque
        # shape error from torch.stack.
        if width != 2 * HALF_WIDTH:
            raise RuntimeError(
                f'Expected a {2 * HALF_WIDTH}-pixel-wide side-by-side stereo '
                f'frame, got {width} pixels.'
            )
        image = cv.copyMakeBorder(image_origin, PAD, PAD, 0, 0,
                                  cv.BORDER_CONSTANT, value=(255, 255, 255))
        imgL = image[:, :HALF_WIDTH, :]
        imgR = image[:, HALF_WIDTH:, :]
        with torch.no_grad():
            # Batch of two: index 0 is the left view, index 1 the right view.
            img = torch.stack((_to_input(imgL), _to_input(imgR)), dim=0)
            predict = model(img)
        image[:, :HALF_WIDTH, :] = image_generate_loc(imgL, predict[0])
        image[:, HALF_WIDTH:, :] = image_generate_loc(imgR, predict[1])
        # Strip exactly the padding added above, whatever the frame height.
        image = image[PAD:PAD + height, :, :]
        cv.imshow('test', image)
        # Mask to the low byte: some platforms/backends set extra high bits
        # in the waitKey return value.
        c = cv.waitKey(10) & 0xFF
        if c == ord('q'):
            break
        elif c == ord('s'):
            cv.imwrite('err.png', image_origin)
finally:
    # Always release the camera and close the window, even on an exception.
    cap.release()
    cv.destroyAllWindows()
